Enable native AMD ROCm inference via Triton kernels #166
jandom merged 5 commits into aqlaboratory:main
Conversation
Adds native AMD ROCm inference support through Triton kernels, including Evoformer attention and TriangleMultiplicativeUpdate inference kernels, along with validation, tests, and AMD-specific install and runtime configuration.

- add Triton Evoformer attention with i64 overflow-safe indexing for long sequences
- add fused Triton kernels for TriangleMultiplicativeUpdate inference
- thread `use_triton_triangle_kernels` through the model stack, matching the existing `use_deepspeed_evo_attention` pattern
- validate ROCm backend execution
- add kernel tests covering forward and backward correctness in bf16 and fp32
- add ready-to-use AMD runner and environment configs
- add optional installation support via `pip install openfold3[rocm]`
jnwei
left a comment
Thank you very much for this contribution @singagan ! It is very thorough, complete with tests. I know several users are looking forward to this addition.
I have two questions and a few review requests relating to documentation.
Question 1: Installation
I understand there is a challenge with the ROCm installation because the ROCm wheels are only available from the version of PyTorch provided through the extra index, e.g. pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2.
In other words, a single line installation for openfold3 and rocm support is not possible at this time.
Pinging @sdvillal and @Emrys-Merlin to see if they have thoughts on an elegant solution for how to handle the extra dependencies. We do not need to add this solution to this PR, but it would be good to think about what we'd like to aim for as a best outcome.
Question 2: Kernel Compilation
I see that the evoformer kernel appears to have some caching and compilation behavior (gated by tl.heuristics) to optimize the kernel sizes.
A few questions about this behavior:
- In practice, a user of OpenFold may submit queries of many different sequences with different sequence lengths. In this case, is compilation and caching performed for each unique sequence length? If so, how much extra time do these compilation steps add to the overall workflow?
- Is there a way to skip the compilation step altogether if the user prefers to skip these steps?
I have limited experience with Triton kernel compilation, so please feel free to correct any misunderstandings / refer me to documentation about compilation.
Documentation requests:
Thank you for providing examples and comments for how to run with ROCm. To make the AMD compatibility modes more visible, it would be best to add these instructions to the main documentation.
In particular:
- The installation documentation, with the custom pip install PyTorch command, may be added here.
- The selection of the inference mode with the Triton kernels could be added to the inference documentation, as one of the inference modes, in this section (readthedocs) (raw).
pyproject.toml:
Since the option pip install openfold3[rocm] doesn't add any relevant dependencies, I think we should remove the option from pyproject.toml for now.
One reason we might keep the option is if we can add validation logic to the installation option, to check if pytorch was installed using the extra rocm indices. Perhaps @sdvillal and @Emrys-Merlin will also have thoughts here.
Thanks for pinging us. @sdvillal and I had a look at the PR and we think that we should be able to add a ROCm-specific environment, e.g.:

pixi run -e openfold3-rocm run_openfold ...
pixi run -e openfold3-cuda12 run_openfold ...
pixi run -e openfold3-cuda13 run_openfold ...
pixi run -e openfold3-cpu run_openfold ...

Most of the groundwork should already be there. It's a pity that ROCm is not yet on conda-forge, but we should be able to get everything we need from PyPI (with the ROCm PyTorch index). I will try to cherry-pick this PR on top of the pixi beta branch and see if I can get a working environment. However, I'm not sure I will get around to it this week. Also, I might need help testing as I don't have direct access to AMD accelerators. I will come back to you if that becomes a blocker. Let me know if that goes in the direction you were thinking @jnwei.
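As a rough sketch of what such a pixi setup might look like — the feature and environment names, the `triton` dependency, and the exact TOML nesting are assumptions, not the actual contents of the pixi beta branch; only the extra index URL comes from the discussion above:

```toml
# Hedged sketch of a pixi manifest fragment for an openfold3-rocm environment.
# A "rocm" feature pulls ROCm-enabled PyTorch wheels from the extra PyPI index.
[feature.rocm.pypi-options]
extra-index-urls = ["https://download.pytorch.org/whl/rocm7.2"]

[feature.rocm.pypi-dependencies]
torch = "*"
triton = "*"

[environments]
openfold3-rocm = ["rocm"]
```

With something along these lines, `pixi run -e openfold3-rocm run_openfold ...` would resolve against the ROCm index without affecting the CUDA or CPU environments.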
- document ROCm install steps and Triton inference mode in Installation.md and inference.md
- add validate-openfold3-rocm console script to verify ROCm environment after install
- remove empty openfold3[rocm] pip extra in favour of plain pip install openfold3
d06f8c5 to c7dd1cc
Thank you for the detailed review @jnwei.
Yes, Triton JIT-compiles a separate kernel variant for each unique sequence length. However, since Triton caches compiled kernels to disk, each length's compilation cost is incurred only on the first query at that length; subsequent queries reuse the cached kernel. Regarding whether the compilation step can be skipped:
Not directly. JIT compilation is intrinsic to how Triton generates GPU-native code and cannot be bypassed. However, since the compiled kernels are cached to disk, the cost is paid once per unique sequence length per machine and never again. Users who need predictable latency from the very first query can pre-warm the cache by running a short dummy forward pass at each expected sequence length before submitting real queries. For any repeated workload, the compilation overhead is zero.
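To make the cost model described above concrete, here is a minimal runnable stand-in (plain Python, not actual Triton) for the "compile once per unique sequence length, then hit the cache" behavior:

```python
# Stand-in for Triton's per-shape JIT caching: the first request for a given
# sequence length pays a one-off "compile" cost; repeats hit the cache.
compile_count = 0
_kernel_cache = {}

def get_kernel(seq_len):
    """Return the (mock) compiled kernel for this sequence length."""
    global compile_count
    if seq_len not in _kernel_cache:
        compile_count += 1  # stands in for a Triton JIT compilation
        _kernel_cache[seq_len] = f"kernel<seq_len={seq_len}>"
    return _kernel_cache[seq_len]

# Five queries, but only three unique lengths -> only three "compilations".
for n in (128, 256, 128, 512, 256):
    get_kernel(n)
```

Pre-warming the cache, as suggested above, amounts to calling the model once per expected length before real queries arrive.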
Earlier, I was planning to keep the option as a hook for future PyTorch wheel support, but I decided to remove the empty openfold3[rocm] extra for now.

I agree with @Emrys-Merlin that adding a pixi environment in a follow-up PR would be a cleaner solution. @Emrys-Merlin, happy to help with testing and verification whenever you need access to AMD hardware.
@singagan Thank you for the explanations regarding the compilation and caching of the kernels. I see that the kernel caching will speed up inference long term, so long as the Triton cache remains. I wonder if there is a recommended way of working with Triton caches for users who may have transient access to compute, for example, if a user runs primarily on AWS instances.

@Emrys-Merlin, your suggestions for the specific pixi environment setups sound great! I think that is a great direction to aim for; we can aim to have a follow-up PR that includes ROCm.

@singagan I have now had the chance to run the full set of unit tests on an AMD GPU (MI210), and I observe a few failures. I am not sure if you have observed these errors in your testing. I had followed the instructions added to the Installation documentation and verified the ROCm environment.

DeepSpeed installation / model configuration issue. Full error message:

RuntimeError: Unable to JIT load the evoformer_attn op due to it not being compatible due to hardware/software issue.
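One hedged answer to the transient-compute question (not an official project recommendation): Triton reads the cache location from the `TRITON_CACHE_DIR` environment variable, so the cache can be redirected to storage that outlives the instance. The path below is illustrative; on AWS this would typically be an EBS/EFS mount:

```shell
# Point Triton's on-disk kernel cache at durable storage so compiled kernels
# survive instance teardown. The path here is an example, not a convention.
export TRITON_CACHE_DIR="${TMPDIR:-/tmp}/openfold3-triton-cache"
mkdir -p "$TRITON_CACHE_DIR"
# Run inference as usual; compiled kernels land in (and are reloaded from)
# the persistent directory on subsequent instances.
```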
Full list of failures:

openfold3/tests/test_kernels.py::TestKernels::test_compare_diffusion_transformer_dsk_bf16
openfold3/tests/test_kernels.py::TestKernels::test_compare_diffusion_transformer_dsk_fp32
openfold3/tests/test_kernels.py::TestKernels::test_compare_pairformer_dsk_bf16
openfold3/tests/test_kernels.py::TestKernels::test_compare_pairformer_dsk_fp32
openfold3/tests/test_kernels.py::TestKernels::test_compare_pairformer_dsk_fp32_chunk
openfold3/tests/test_kernels.py::TestKernels::test_compare_template_stack_dsk_bf16
openfold3/tests/test_kernels.py::TestKernels::test_compare_template_stack_dsk_fp32
openfold3/tests/test_kernels.py::TestKernels::test_compare_template_stack_dsk_fp32_chunk
openfold3/tests/test_kernels.py::TestKernels::test_dsk_backward_bf16
openfold3/tests/test_kernels.py::TestKernels::test_dsk_backward_fp32
openfold3/tests/test_kernels.py::TestKernels::test_dsk_forward_bf16
openfold3/tests/test_kernels.py::TestKernels::test_dsk_forward_fp32
openfold3/tests/test_of3_model.py::TestOF3Model::test_shape_small_kernels[model=train-dtype=torch.float32]
openfold3/tests/test_of3_model.py::TestOF3Model::test_shape_small_kernels[model=train-dtype=torch.bfloat16]
openfold3/tests/test_of3_model.py::TestOF3Model::test_shape_small_kernels[model=eval-dtype=torch.float32]
openfold3/tests/test_of3_model.py::TestOF3Model::test_shape_small_kernels[model=eval-dtype=torch.bfloat16]
openfold3/tests/test_of3_model.py::TestOF3Model::test_shape_large_eval[dtype=torch.float32]
openfold3/tests/test_of3_model.py::TestOF3Model::test_shape_large_eval[dtype=torch.bfloat16]
Example preset addition:

triton:
  settings:
    memory:
      eval:
        use_triton_triangle_kernels: true
        use_deepspeed_evo_attention: false
        use_cueq_triangle_kernels: false

Then in test_of3_model.py, the preset can be applied, and to run inference with the Triton kernels the triton.yml simplifies accordingly.

LMDB cache writing error. Full error:

test_lmdb.py: lmdb.Error: The environment '/tmp/pytest-of-jwei22/pytest-1/test_lmdb_roundtrip0/test_lmdb' is already open in this process.
…tests

- add skip_if_rocm() decorator in compare_utils.py
- update skip_unless_ds4s_installed() to also skip on ROCm/HIP
- add use_triton_triangle_kernels param to run_model in test_of3_model.py
- test_shape_small_kernels, test_shape_large_eval, test_shape_large_bf16_train now use Triton kernels on ROCm instead of failing on DeepSpeed
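For readers without the diff at hand, a ROCm detection helper along these lines can back such a decorator. This sketch uses a stand-in for `torch.version` so it runs without PyTorch installed; on a real ROCm/HIP build of PyTorch, `torch.version.hip` is a version string, otherwise it is None:

```python
import unittest
from types import SimpleNamespace

# Stand-in for torch.version so the sketch runs without PyTorch installed.
# In the real helper this would simply be `import torch; torch.version`.
torch_version = SimpleNamespace(hip=None)

def is_rocm() -> bool:
    # ROCm/HIP builds of PyTorch expose a non-None torch.version.hip
    return getattr(torch_version, "hip", None) is not None

def skip_if_rocm(reason: str = "not supported on ROCm/HIP"):
    # Shaped like a unittest-style conditional skip decorator
    return unittest.skipIf(is_rocm(), reason)
```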
@jnwei Thank you for the detailed report. I have pushed updates for the DeepSpeed-related tests on ROCm. I went with the manual configuration approach rather than the model preset to keep the changes minimal for now.

LMDB cache writing error (test_lmdb.py): Observed this as well. Good to know this is a known issue tied to the Python version rather than the AMD build. Happy to defer to the separate PR.
Thank you for adding the configuration changes for test_kernels.py and test_of3_models.py. I see that test_kernels.py on AMD now correctly skips all the tests except the Triton tests.

For the test_of3_models.py unit tests, I observe the following errors on all of the tests with the recent additions. My read of the error is that the Linear layer is not configured with the fused Triton linear modules that are specified in the PR description. Did you encounter this error?

FAILED openfold3/tests/test_of3_model.py::TestOF3Model::test_shape_small_fp32[model=train] - RuntimeError: CUDA error: HIPBLAS_STATUS_INTERNAL_ERROR when calling hipblasLtMatmul with transpose_mat1 1 transpose_ma...

Full stack trace:
…ibrary in test conftest
Observed this on specific systems while using hipBLASLt. We switched to rocBLAS, which was already configured in import_utils.py for inference but not applied during tests. Added a session-scoped autouse fixture in conftest.py to apply the same setting for the full test suite.
jnwei
left a comment
LGTM, with the latest conftest fixture for specifying rocBLAS, I see that the of3_model tests and the kernel tests all pass.
Thank you for your contribution to the OpenFold community @singagan ! I am sure many users will find these additions useful.
Enable native AMD ROCm inference via Triton kernels
Summary
OpenFold3 now supports high-performance native inference on AMD GPUs with ROCm.
OpenFold3 previously lacked a high-performance inference path on AMD GPUs
because Evoformer attention and TriangleMultiplicativeUpdate depended on
CUDA-specific kernels with no ROCm support. This PR unlocks native AMD
inference by replacing those dependencies with Triton kernels and wiring the
new path through the full OpenFold3 model stack.
It also adds ROCm validation, kernel correctness tests, and ready-to-use AMD
installation and runtime configuration, making native AMD inference practical
out of the box.
Changes

- add Triton Evoformer attention kernels with i64 overflow-safe indexing for long sequences
- add fused Triton kernels in the TriangleMultiplicativeUpdate inference path
- thread use_triton_triangle_kernels through the full model stack, matching the existing use_deepspeed_evo_attention/use_cueq_triangle_kernels pattern
- add examples/example_runner_yamls/triton.yml as a ready-to-use runner config
- add optional installation via pip install openfold3[rocm]
- add environments/production-amd-linux-64.yml for AMD conda environments

Related Issues
Testing